Porting Flask to FastAPI for ML Model Serving
One of the common difficulties I’ve seen with new Data Scientists coming into industry is getting their machine learning models interacting with production systems - that is, making the jump from their research environment (usually developing in a Jupyter notebook or similar) to a broader code ecosystem. Putting the model behind a REST API ensures that we can serve predictions to consumers on the fly, independent of the compute resources, scale, and even the languages chosen for the rest of the production ecosystem.
Flask is a very popular framework for building REST APIs in Python, with nearly half of Python coders reporting its use in 2019. In particular, Flask is useful for serving ML models, where simplicity & flexibility are more desirable than the “batteries included” all-in-one functionality of other frameworks geared more towards general web development. However, since Flask’s introduction there have been a number of developments to Python’s performance, type annotation behavior, and asynchronous programming interface, allowing us to build a faster, more robust, and more responsive API - here, we’ll learn how to migrate to the newer FastAPI framework to take advantage of these advances.
This post will assume some familiarity with Flask and HTTP requests, though we’ll go through the app’s construction in Flask first. Complete dockerized code for this post may be found here. For the purposes of this article, we’ll be covering:
- setting up the FastAPI app and running the async server with uvicorn and gunicorn
- route specification and endpoints
- data validation
- asynchronous endpoints
Things we won’t be covering:
- unit testing your API - but you can design tests using the starlette TestClient that behave pretty much like Flask's
- deployment, since that's very organization-specific - but the app is easily Dockerized and runs on production-ready servers with gunicorn, and we've included an example configuration in the full code sample.
a simple Flask API
First, we’ll need a model.
The demo code linked above provides a script to train a simple Naive Bayes classifier on a subset of the “20 newsgroups” dataset available in scikit-learn (using the rec.sport.hockey, sci.space, and talk.politics.misc newsgroups), but feel free to build and pickle your own text classifier on the same data - any ML package with a reasonable API will work fine in this framework.
Let’s start with a simple, single-file Flask app serving this model. First, the header content:
import os
import pickle

from flask import Flask, jsonify, request
import numpy as np

app = Flask(__name__)

with open(os.getenv("MODEL_PATH"), "rb") as rf:
    clf = pickle.load(rf)


@app.route("/healthcheck", methods=["GET"])
def healthcheck():
    msg = (
        "this sentence is already halfway over, "
        "and still hasn't said anything at all"
    )
    return jsonify({"message": msg})
in which we set up our Flask app object and load the model into memory on app startup - the actual prediction methods will reference this model object as a closure, avoiding unnecessary file I/O for each model call.
(We’ll set aside finding a better model loading paradigm than pickle for now, as that will depend on the specifics of the deployment scheme and model involved.)
Including a “health check” endpoint that returns a simple “I’m ok” status to GET requests, without doing any computation, is useful both for deployment monitoring and for ensuring that your own development environment is set up correctly.
The above should produce a working (albeit slightly silly) API in your environment - test it out by running from your shell
$ export FLASK_APP=<path to your app file>.py
$ flask run
(or use the dockerized invocation from the example code).
prediction endpoints
Next, let’s add some actual prediction functionality:
@app.route("/predict", methods=["POST"])
def predict():
samples = request.get_json()["samples"]
data = np.array([sample["text"] for sample in samples])
probas = clf.predict_proba(data)
predictions = probas.argmax(axis=1)
return jsonify(
{
"predictions": (
np.tile(clf.classes_, (len(predictions), 1))[
np.arange(len(predictions)), predictions
].tolist()
),
"probabilities": probas[np.arange(len(predictions)), predictions].tolist()
}
)
This is the typical design pattern we use for serving predictions - clients POST input data to the endpoint and receive predictions back in a fixed format, with the model itself called via closure.
This expects a request body of the form
{
    "samples": [
        {
            "text": "this is a text about hockey"
        }
    ]
}
So we’re passing it an array of samples, each of which contains a key-value set of the text input data we’re expecting (for models with multiple distinct features, pass each feature as a key-value pair).
This ensures we’re explicitly accessing the data in the form we want, plays nicely with pandas and other tabular data tools, and mirrors the shape of individual datapoints we might expect from a production data source.
Our predict() method unpacks this into an array, passes it through the scikit-learn classifier pipeline object, and returns the predicted label (mapped to the class name) and its corresponding probability.
Note that we could’ve simply called the classifier’s predict method rather than mapping the predict_proba probability outputs back to class labels, but some simple numpy algebra should generally be substantially faster than running the inference a second time.
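If that indexing looks opaque, here is a tiny, self-contained sketch with made-up probabilities (not our model's actual output) of what the numpy step is doing - picking out the class name and probability of each row's highest-probability column:

import numpy as np

classes = np.array(["rec.sport.hockey", "sci.space", "talk.politics.misc"])
probas = np.array([[0.7, 0.2, 0.1],   # pretend predict_proba output for two samples
                   [0.1, 0.3, 0.6]])

predictions = probas.argmax(axis=1)  # index of the top class per row -> [0, 2]
labels = classes[predictions]        # map those indices back to class names
top_probas = probas[np.arange(len(predictions)), predictions]  # probability of each top class

print(labels.tolist())      # ['rec.sport.hockey', 'talk.politics.misc']
print(top_probas.tolist())  # [0.7, 0.6]

The tile-based indexing in the endpoint above performs the same class-name lookup for a full batch.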
@app.route("/predict/<label>", methods=["POST"])
def predict_label(label):
samples = request.get_json()["samples"]
data = np.array([sample["text"] for sample in samples])
probas = clf.predict_proba(data)
target_idx = clf.classes_.tolist().index(label)
return jsonify({"label": label, "probabilities": probas[:, target_idx].tolist()})
Similarly, we can request the specific probability of a given label for each sample - this type of parameterization is readily done through the path itself. And there we have it!
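To try it from Python - assuming the app is running locally on Flask's default development port of 5000, so adjust the URL to match your setup - something like the following should return predictions:

import requests  # the plain, blocking requests package is fine for a quick manual test

payload = {"samples": [{"text": "this is a text about hockey"}]}

resp = requests.post("http://localhost:5000/predict", json=payload)
print(resp.json())  # e.g. {"predictions": ["rec.sport.hockey"], "probabilities": [...]}

resp = requests.post("http://localhost:5000/predict/sci.space", json=payload)
print(resp.json())  # {"label": "sci.space", "probabilities": [...]}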
This API will work for serving predictions (try it out yourself!)… but it’s entirely likely it will fall over if you try to put it into production.
Notably, it lacks any data validation - any missing, misnamed, or mistyped information in the incoming request results in an unhandled exception, returning a 500 error and a singularly unhelpful response from Flask (along with some rather intimidating HTML in debug mode), while potentially firing alerts into your monitoring systems and waking up your DevOps team.
Frequently, error handling in Flask ends up as a brittle, tangled jumble of try/except blocks and protected dict access.
Better approaches will use a package like pydantic or marshmallow to achieve more programmatic data validation.
Fortunately, FastAPI includes pydantic validation out of the box.
enter FastAPI
FastAPI is a modern Python web framework designed to:
- provide a lightweight microframework with an intuitive, Flask-like routing system
- take advantage of the type annotation support in Python 3.6+ for data validation and editor support
- utilize improvements to Python’s async support and the development of the ASGI specification to make for easier asynchronous APIs
- automatically generate useful API documentation using OpenAPI and JSON Schema
Under the hood, FastAPI uses pydantic for data validation and starlette for its web tooling, making it ludicrously fast compared to frameworks like Flask and giving comparable performance to high-speed web APIs in Node or Go.
Fortunately, since FastAPI explicitly drew on Flask to inspire its route specification, transitioning to it is quite quick - let’s start porting our serving app’s functionality over. For the corresponding header content:
import os
import pickle

from fastapi import FastAPI
import numpy as np

app = FastAPI()

with open(os.getenv("MODEL_PATH"), "rb") as rf:
    clf = pickle.load(rf)


@app.get("/healthcheck")
def healthcheck():
    msg = (
        "this sentence is already halfway over, "
        "and still hasn't said anything at all"
    )
    return {"message": msg}
Easy so far - we’ve only had to make minor semantic changes:
- instantiating the Flask(__name__) top-level object becomes instantiating a FastAPI() object
- routes are still specified with a decorator, but with the HTTP method in the decorator itself rather than an argument - @app.route(..., methods=["GET"]) becomes @app.get(...)
We can invoke this much like Flask’s app launch, although FastAPI doesn’t include a built-in webserver - instead, we’ll launch the uvicorn server directly,
$ uvicorn --reload <path to app file>:app
which gives us a simple single-worker setup (or, again, just use the dockerized invocation in the example code).
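For a production-style, multi-worker setup, the usual pattern is to run uvicorn workers under gunicorn - this is roughly what the example configuration in the full code sample does - with an invocation along the lines of (worker count chosen arbitrarily here)

$ gunicorn -w 4 -k uvicorn.workers.UvicornWorker <path to app file>:app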
data validation
For specifying the prediction endpoint, we do have to introduce one major change to our thinking.
In the simple validations mentioned above, things like try/except blocks and protected dict access are essentially trying to slap away everything we don’t want in our data.
What if, instead, we simply specified what we do want and let the app handle the rest?
This is exactly what validation tools like marshmallow or pydantic do - with pydantic in FastAPI, we simply specify the schema using Python’s (3.6+) type annotations and pass it as an argument to the route function.
FastAPI picks up on those annotations to perform the validation, meaning we need only specify what we expect for our inputs, very naturally, and let the rest happen under the hood.
For our predict() endpoint, this looks like
from typing import List

from pydantic import BaseModel


class TextSample(BaseModel):
    text: str


class RequestBody(BaseModel):
    samples: List[TextSample]
We simply specify the expected data shape in a child class of pydantic.BaseModel - base Python type annotations are all we need here, though pydantic also supports extended type checks for things like string/email validation or array dimensions & norms (a quick sketch of these follows below).
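As a quick aside (this isn't part of our serving app, and the constraints are arbitrary), those extended checks look like ordinary annotations too - pydantic's constrained types let us tighten the schema without writing any validation code:

from typing import List

from pydantic import BaseModel, conlist, constr


class StrictTextSample(BaseModel):
    text: constr(min_length=1)  # reject empty strings rather than accepting "" silently


class StrictRequestBody(BaseModel):
    samples: conlist(StrictTextSample, min_items=1)  # require at least one sample per request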
While pydantic builds its type checking functionality into these classes, we can still use them like ordinary Python objects - subclassing or adding additional functions works as expected.
For example, we could build the unpacking of our samples into an array for inference right into the class, like
class RequestBody(BaseModel):
    samples: List[TextSample]

    def to_array(self):
        return [sample.text for sample in self.samples]
replacing the list comprehension used above and guaranteeing correct array formatting. In the endpoint, we pass the input data as a function argument:
@app.post("/predict")
def predict(body: RequestBody):
data = np.array(body.to_array())
probas = clf.predict_proba(data)
predictions = probas.argmax(axis=1)
return {
"predictions": (
np.tile(clf.classes_, (len(predictions), 1))[
np.arange(len(predictions)), predictions
].tolist()
),
"probabilities": probas[np.arange(len(predictions)), predictions].tolist(),
}
FastAPI will handle this intelligently, finding route parameters by name first, then packing request bodies (for POST requests) or query parameters (for GET) into the function arguments.
The resulting data fields or methods can then be accessed through typical attribute/method syntax.
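For example (not part of our serving app - just a sketch of that argument-packing behavior), a GET route can pick query parameters straight out of its signature:

@app.get("/labels")
def labels(limit: int = 3):
    # GET /labels?limit=2 -> limit arrives parsed and validated as an int
    return {"labels": clf.classes_.tolist()[:limit]}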
In the case of malformed input data, pydantic will raise a validation error - FastAPI handles this internally, returning a 422 error code with a JSON body containing useful information about the error.
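For instance, POSTing a sample that's missing its text field should return a 422 with a body along these lines (the exact fields depend on your FastAPI and pydantic versions):

{
    "detail": [
        {
            "loc": ["body", "samples", 0, "text"],
            "msg": "field required",
            "type": "value_error.missing"
        }
    ]
}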
enumerated values & path parameters
We can also use enumerated values in our data validation - for example, in the predict_label() endpoint we can handle valid target names with
from enum import Enum


class ResponseValues(str, Enum):
    hockey = "rec.sport.hockey"
    space = "sci.space"
    politics = "talk.politics.misc"
Passing this to validate the route parameter will cleanly handle errors for bad target names, which would otherwise choke when trying to find the index of the corresponding value in clf.classes_.
In the endpoint, we then have
@app.post("/predict/{label}")
def predict_label(label: ResponseValues, body: RequestBody):
data = np.array(body.to_array())
probas = clf.predict_proba(data)
target_idx = clf.classes_.tolist().index(label.value)
return {"label": label.value, "probabilities": probas[:, target_idx].tolist()}
response models and documentation
So far, we’ve only added data validation on our inputs, but FastAPI allows us to declare a schema for the output as well - we define
class ResponseBody(BaseModel):
    predictions: List[str]
    probabilities: List[float]


class LabelResponseBody(BaseModel):
    label: str
    probabilities: List[float]
and replace the route decorators with
@app.post("/predict", response_model=ResponseBody)
...
@app.post("/predict/{label}", response_model=LabelResponseBody)
...
It may seem strange to add data validation to our outputs - after all, if what we’re returning doesn’t conform to the schema, that indicates a deeper issue in our code.
However, we can use these schemas to restrict exactly what data to return to an external user (e.g., by removing sensitive fields from an internal message).
More importantly for us, these schemas get incorporated into FastAPI’s auto-generated documentation - when the API is running, hitting the {api url}/docs or {api url}/redoc endpoints will load OpenAPI-generated documentation detailing the available endpoints and the expected structure of their inputs and outputs (derived from our schemas).
We can even add annotations to the API and its endpoints:
- individual routes can be tagged for categorization (useful for API versioning)
- route function docstrings will be pulled into the API doc for endpoint description
- the FastAPI object itself can take title, description, and version keyword arguments, which get populated into the documentation (see the sketch below)
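A sketch of what that might look like for our app - the title, description, version, tag name, and docstring here are all just illustrative placeholders:

app = FastAPI(
    title="newsgroup text classifier API",
    description="classify text among a subset of the 20 newsgroups categories",
    version="0.1.0",
)


@app.post("/predict", response_model=ResponseBody, tags=["prediction"])
def predict(body: RequestBody):
    """Return the most likely class (and its probability) for each text sample."""
    ...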
asynchronous endpoints
We still haven’t touched on one of the most powerful aspects of FastAPI - its clean handling of asynchronous code. Frankly, unfamiliarity with async isn’t uncommon among data scientists and ML engineers - so much model training and inference is processor-bound that asynchronous code just doesn’t come up all that much, compared to (for example) web development, where it’s far more common. For a deeper dive, check out Python’s own async documentation, though I actually find FastAPI’s explanation much more intuitive for grasping the differences between concurrent and parallel code.
In short, the idea of asynchronous execution (or concurrency, if you prefer) concerns your process firing work off to an external resource, like a request to an external API or datastore. In synchronous code, the process blocks until that work is completed - for a slow request, that means the entire process sits idle until it receives a return value. Asynchronous execution allows the process to switch contexts and work on something unrelated until it’s signaled that the requested work has completed, at which point it resumes. (In contrast, parallel execution would have multiple lines of work all executing that potentially-blocking code independently of each other.)
In machine learning, execution is typically processor-bound - that is, the work consistently subscribes the full processing capacity of one or more CPU cores.
Async execution isn’t particularly helpful in this case, as there aren’t really situations where work can be put down by the processor to wait for a result while executing something else (in contrast, there are ample scenarios where ML work can be computed in parallel).
However, we can still envision situations where our model serving API might be waiting on an external resource rather than doing computational heavy lifting on its own.
For example, suppose that our API must request information from a database or in-memory cache based on its computation, or that our API is a lightweight middleman performing validation or preprocessing before farming work off to a separate tensorflow-serving API running on a GPU instance - properly handling asynchronous processing can give our work a significant performance boost for little cost.
Historically, async work in Python has been nontrivial (though its async API has improved rapidly since Python 3.4), particularly with Flask.
Essentially, Flask (on most WSGI servers) is blocking by default - work triggered by a request to a particular endpoint will hold the server entirely until that request is completed.
Instead, Flask (or rather, the WSGI server running it, like gunicorn or uWSGI) achieves scaling by running multiple worker instances of the app in parallel, so that requests can be farmed out to other workers while one is busy.
Within a single worker, asynchronous work can be wrapped in a blocking call (the route function itself is still blocking), threaded (in newer versions of Flask), or farmed out to a queue manager like Celery - but there isn’t a single consistent story where routes can cleanly handle asynchronous requests without additional tooling.
async with FastAPI
In contrast, FastAPI is designed from the ground up to run asynchronously - thanks to its underlying starlette ASGI framework, route functions default to running within an asynchronous event loop.
With a good ASGI server (FastAPI is designed to couple to uvicorn, running on top of uvloop) this can get us performance on par with fast asynchronous webservers in Go or Node, without losing the benefits of Python’s broader machine learning ecosystem.
In contrast to messing with threads or Celery queues to achieve asynchronous execution in Flask, running an endpoint asynchronously is dead simple in FastAPI - we simply declare the route function as asynchronous (with async def) and we’re ready to go!
We can even do this if the route function isn’t conventionally asynchronous - that is, we don’t have any awaitable calls (like if the endpoint is running inference against an ML model).
In fact, unless the endpoint is specifically performing a blocking IO operation (to a database, for example), it’s better to declare the function with async def (route functions declared with plain def are punted to an external threadpool and awaited anyhow).
For our ML prediction functions above, we can declare the endpoints with async def, though that doesn’t really make any interesting changes to our code.
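Concretely, the only change is the declaration itself - for instance, for the /predict route (body elided, since it's identical to the version above):

@app.post("/predict", response_model=ResponseBody)
async def predict(body: RequestBody):
    ...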
But what if we needed to do something truly asynchronous, like requesting (and waiting for) a resource from an external API?
Unfortunately, our conventional requests package in Python is blocking, so we can’t use it to make HTTP requests asynchronously - instead, we’ll use the request functionality in the excellent aiohttp package.
async requests
First, we’ll need to set up a client session - this will keep a persistent pool running to await requests from, rather than creating a new session for each request (which is actually what requests does if called via the typical requests.get, requests.post, etc.).
We’ll put this at the top level of the app, such that any route function can call it via closure:
import aiohttp

app = FastAPI()

...

client_session = aiohttp.ClientSession()
We’ll also need to ensure that this session closes down properly - fortunately, FastAPI gives us an easy decorator to declare these operations:
@app.on_event("shutdown")
async def cleanup():
    await client_session.close()
This will execute anything called within the function (awaiting the clean shutdown of the aiohttp client session, here) when the FastAPI app shuts down.
For the external request, we wrap an awaitable call in the route function:
@app.get("/cat-facts", response_model=TextSample)
async def cat_facts():
url = "https://cat-fact.herokuapp.com/facts/random"
async with client_session.get(url) as resp:
response = await resp.json()
return response
placing the request within an asynchronous context block, then awaiting a parseable response.
Here we’ve made use of our response models to restrict return values - the response from the “cat facts” API returns a lot of additional metadata about the fact, but we only want to return the fact text.
Rather than fiddling with the response before returning, we can simply reuse our existing TextSample schema to pack it into the response and trust pydantic to take care of the filtering, so our response looks like
{
    "text": "In an average year, cat owners in the United States spend over $2 billion on cat food."
}
and that’s it!
We can use this construction for any asynchronous callouts to external resources that we may need, like retrieving data from a datastore or firing inference jobs off to tensorflow-serving running on GPU resources.
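As a sketch of that last case - the service URL, model name, and payload shape here are hypothetical (TF Serving's REST API expects an "instances" list), so adapt them to your own deployment:

@app.post("/predict-remote")
async def predict_remote(body: RequestBody):
    # hypothetical tensorflow-serving REST endpoint; swap in your own host and model name
    url = "http://tf-serving:8501/v1/models/text_classifier:predict"
    async with client_session.post(url, json={"instances": body.to_array()}) as resp:
        predictions = await resp.json()
    return predictions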
wrapping up
In this post, we’ve walked through a common, simple layout for standing your machine learning models up behind a REST API, enabling predictions on the fly in a consumable way that should interface cleanly with a variety of code environments. While Flask is an extremely common framework for these tasks, we can take advantage of improvements to Python’s type checking and asynchronous support by migrating to the newer FastAPI framework - fortunately, porting to FastAPI from Flask is straightforward!